import torch
from tqdm import tqdm

from .base import Trainer


class LogitClip(Trainer):
    """Trainer that clips the Lp norm of the logits before computing the loss.

    Reads two extra hyper-parameters from ``config["train"]``:

    * ``temp`` -- norm threshold above which a logit vector is rescaled.
    * ``lp``   -- order ``p`` of the norm passed to ``torch.norm``.

    The training loop relies on attributes and helpers that are not defined
    in this class (``self.epoch``, ``self.train_loader``, ``self.global_iter``,
    ``self.save_every_epoch``, ``prepare_data``, ``get_lr``, ``evaluate``,
    ``save_model``, ...); they are expected to come from the ``Trainer`` base.
    """

    def __init__(
        self,
        config,
        model,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        scheduler=None,
        val_set=None,
    ):
        """Forward all arguments to ``Trainer`` and cache the clipping settings."""
        super().__init__(
            config,
            model,
            logger,
            train_set,
            test_set,
            criterion,
            optimizer,
            scheduler,
            val_set,
        )
        self.temp = self.config["train"]["temp"]
        self.lp = self.config["train"]["lp"]

    def train_batch(
        self,
        outputs,
        targets,
    ):
        """Return ``self.criterion`` evaluated on norm-clipped logits.

        For every vector along the last dimension of ``outputs`` the Lp norm
        (plus a ``1e-7`` stabiliser) is computed. Vectors whose norm exceeds
        ``self.temp`` are divided by that norm and multiplied by
        ``1 / self.temp``; all other vectors are passed through unchanged.

        Args:
            outputs: Raw model logits; the last dimension is the one that is
                normalised (presumably ``(batch, num_classes)`` -- TODO confirm).
            targets: Targets forwarded untouched to ``self.criterion``.

        Returns:
            Whatever ``self.criterion`` returns for the clipped logits.
        """
        logits = outputs
        delta = 1 / self.temp
        # Small epsilon keeps the division below finite for all-zero logits.
        norms = torch.norm(logits, p=self.lp, dim=-1, keepdim=True) + 1e-7
        logits_norm = torch.div(logits, norms) * delta
        # NOTE(review): the threshold is ``temp`` but rescaled vectors end up
        # with norm ``1 / temp``; the two only coincide for temp == 1 --
        # confirm this is intended.
        # ``norms`` keeps a trailing size-1 dim, so the mask broadcasts over
        # the last dimension for logits of any rank.
        clip = norms > self.temp
        logits_final = torch.where(clip, logits_norm, logits)
        return self.criterion(logits_final, targets)

    def run(self):
        """Train for ``self.epoch`` epochs, evaluating and checkpointing each epoch.

        Per batch: forward pass, clipped-logit loss, backward pass, optimizer
        step, progress-bar update and periodic logging. Per epoch: optional
        validation, test evaluation, best/periodic/last checkpointing and an
        optional scheduler step.
        """
        print("==> Start training..")
        best_acc = 0.0
        log_every = self.config["general"]["logger"]["frequency"]
        for cur_epoch in range(self.epoch):
            self.model.train()
            epoch_loss, epoch_correct, total_num = 0.0, 0.0, 0.0
            with tqdm(self.train_loader, unit="batch") as tepoch:
                tepoch.set_description(f"Epoch {cur_epoch}")
                for data in tepoch:
                    inputs, labels, _, _ = self.prepare_data(data)
                    self.optimizer.zero_grad()
                    outputs = self.model(inputs)
                    loss = self.train_batch(outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    # Convert to a Python float once. Accumulating the tensor
                    # itself would keep every batch's autograd history
                    # referenced until the end of the epoch.
                    loss_value = loss.item()
                    batch_size = inputs.size(0)
                    # Accuracy is measured on the raw (unclipped) outputs.
                    correct = (outputs.argmax(1) == labels).sum().item()
                    batch_acc = 100.0 * correct / batch_size
                    lr = self.get_lr()
                    tepoch.set_postfix(
                        loss=loss_value,
                        accuracy=batch_acc,
                        lr=lr,
                    )
                    epoch_loss += loss_value
                    epoch_correct += correct
                    total_num += batch_size
                    self.global_iter += 1
                    if self.global_iter % log_every == 0:
                        self.logger.info(
                            f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss: {loss_value:.4f}, Acc: {batch_acc:.4f}, lr: {lr:.6f}",
                            {
                                "cur_epoch": cur_epoch,
                                "iter": self.global_iter,
                                "loss": loss_value,
                                "Accuracy": batch_acc,
                                "lr": lr,
                            },
                        )
            # NOTE(review): this divides the sum of per-batch losses by the
            # number of samples; if the criterion already averages over the
            # batch, the reported value is scaled down by the batch size --
            # confirm the intended normalisation.
            epoch_loss /= total_num
            epoch_acc = epoch_correct / total_num * 100.0
            if self.val_set:
                _ = self.evaluate(val=True)
            test_acc = self.evaluate(val=False)

            if test_acc > best_acc:
                best_acc = test_acc
                self.save_best_model()
            print(
                f"Epoch: {cur_epoch}, Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}"
            )
            self.logger.info(
                f"[{cur_epoch}]/[{self.epoch}], Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}",
                {
                    "test_epoch": cur_epoch,
                    "loss": epoch_loss,
                    "Train Acc": epoch_acc,
                    "Test Acc": test_acc,
                    "Best Test Acc": best_acc,
                },
            )

            if self.scheduler:
                self.scheduler.step()
            if cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")
            self.save_last_model()
